# coding=utf-8
# Copyright 2022 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# An implementation of SM3 from:
#
# Memory-Efficient Adaptive Optimization, https://arxiv.org/pdf/1901.11150.pdf
# Rohan Anil, Vineet Gupta, Tomer Koren, Yoram Singer
#
# Author: Rohan Anil (rohananil at google dot com)
#

"""SM3 Implementation."""

import functools
from typing import Any, NamedTuple

import chex
import jax
import jax.numpy as jnp
import optax

from .quantization_utils import QuantizedValue


class SM3State(NamedTuple):
    count: chex.Array
    stats: Any


# Per-parameter optimizer state used in data-parallel training.
class ParameterStats(NamedTuple):
    """State associated with each parameter of the model being trained."""

    diagonal_statistics: Any  # Per-dimension accumulators for the diagonal preconditioner
    diagonal_momentum: QuantizedValue  # Quantized momentum of the preconditioned gradients
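
    # Illustration: for a parameter of shape [M, N], `diagonal_statistics` holds
    # the list [jnp.zeros([M]), jnp.zeros([N])] and `diagonal_momentum` stores an
    # int8-quantized tensor of shape [M, N]; see `init_fn` below.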


def sm3(
    learning_rate, beta1=0.9, beta2=0.999, diagonal_epsilon=1e-10, normalize_grads=False
):
    """SM3 optimizer.

    Memory-Efficient Adaptive Optimization, Rohan Anil, Vineet Gupta, Tomer Koren,
      Yoram Singer

    https://arxiv.org/abs/1901.11150

    Args:
      learning_rate: the step size used to update the parameters. May be a float
        or a schedule (a callable mapping the step count to a step size).
      beta1: momentum parameter.
      beta2: second moment averaging parameter.
      diagonal_epsilon: small constant added to the second-moment statistics
        before taking the inverse square root, for numerical stability.
      normalize_grads: whether to normalize each gradient tensor by its L2 norm
        before preconditioning. The authors find this useful when gradients have
        high variance.

    Returns:
      a GradientTransformation.
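
    Example (illustrative; `params` and `grads` stand in for arbitrary pytrees
    of arrays):

      tx = sm3(learning_rate=0.1)
      opt_state = tx.init(params)
      updates, opt_state = tx.update(grads, opt_state)
      new_params = optax.apply_updates(params, updates)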
    """

    def _quantize_momentum(momentum_statistics):
        return QuantizedValue.from_float_value(momentum_statistics, jnp.int8)

    def init_fn(params):
        """Initialise the optimiser's state."""

        def _init(param):
            accumulators = [jnp.zeros([s]) for s in param.shape]
            momentum = _quantize_momentum(jnp.zeros_like(param))
            return ParameterStats(accumulators, momentum)

        return SM3State(
            count=jnp.zeros([], jnp.int32), stats=jax.tree_map(_init, params)
        )

    def _get_expanded_shape(shape, i):
        rank = len(shape)
        # Replaces a `shape` of [M, N, K] with 1 in every dimension except i.
        # E.g., i = 1 returns [1, N, 1].
        return [1] * i + [shape[i]] + [1] * (rank - i - 1)

    def _moving_averages(grad, accumulators):
        w = (1.0 - beta2) if beta2 != 1.0 else 1.0
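        # Parameters of rank < 2 use their single stored accumulator directly; for
        # higher-rank tensors the second-moment estimate of each entry is the
        # elementwise min of its broadcast per-dimension accumulators (the SM3
        # "cover" construction), updated as an EMA with weight `w`.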
        if grad.ndim < 2:
            return beta2 * accumulators[0] + w * grad**2
        else:
            min_accumulator = functools.reduce(jnp.minimum, accumulators)
            return beta2 * min_accumulator + w * grad**2

    def _moving_averages_momentum(grad, momentum):
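        # Standard EMA of the preconditioned gradient; the stored momentum is
        # int8-quantized, so dequantize it via `to_float()` before mixing.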
        w = (1.0 - beta1) if beta1 != 1.0 else 1.0
        return beta1 * momentum.to_float() + w * grad

    def _sketch_diagonal_statistics(grad, updated_diagonal_statistics):
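        # Re-compress the full-shape statistics into one 1-D accumulator per
        # dimension by taking the max over all other axes, so each accumulator
        # upper-bounds the statistic of every entry it covers.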
        all_diagonal_statistics = []
        for i in range(grad.ndim):
            axes = list(range(i)) + list(range(i + 1, grad.ndim))
            dim_diagonal_statistics = jnp.max(updated_diagonal_statistics, axis=axes)
            all_diagonal_statistics.append(dim_diagonal_statistics)
        if grad.ndim == 1:
            all_diagonal_statistics[0] = updated_diagonal_statistics
        return all_diagonal_statistics

    def update_fn(updates, state, params=None):
        del params
        stats = state.stats
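        # Optionally normalize each gradient tensor by its global L2 norm.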
        if normalize_grads:
            updates = jax.tree_map(lambda g: g / (jnp.linalg.norm(g) + 1e-16), updates)
        # Broadcast each 1-D accumulator back to the parameter's rank so the
        # elementwise min can be taken over them, e.g. [n], [m] -> [n, 1], [1, m].
        expanded_diagonal_statistics = jax.tree_map(
            lambda grad, state: [  # pylint:disable=g-long-lambda
                jnp.reshape(
                    state.diagonal_statistics[i], _get_expanded_shape(grad.shape, i)
                )
                for i in range(grad.ndim)
            ],
            updates,
            stats,
        )

        # Compute new diagonal statistics
        new_diagonal_statistics = jax.tree_map(
            _moving_averages, updates, expanded_diagonal_statistics
        )

        # Compute preconditioners 1/sqrt(s + epsilon), where s is the statistics.
        new_preconditioners = jax.tree_map(
            lambda t: 1.0 / jnp.sqrt(t + diagonal_epsilon), new_diagonal_statistics
        )
        preconditioned_grads = jax.tree_map(
            lambda g, p: g * p, updates, new_preconditioners
        )

        # Update the momentum buffer as an EMA of the preconditioned gradients
        # (dequantizing the stored int8 momentum on the fly).
        updated_momentum = jax.tree_map(
            lambda preconditioned_grad, state: _moving_averages_momentum(  # pylint:disable=g-long-lambda
                preconditioned_grad, state.diagonal_momentum
            ),
            preconditioned_grads,
            stats,
        )

        # Re-sketch the full statistics into per-dimension accumulators.
        updated_diagonal_statistics = jax.tree_map(
            _sketch_diagonal_statistics, updates, new_diagonal_statistics
        )

        # Pack the new accumulators and re-quantized momentum into per-parameter state.
        new_sm3_stats = jax.tree_map(
            lambda momentum, diagonal_stats: ParameterStats(  # pylint:disable=g-long-lambda
                diagonal_stats, _quantize_momentum(momentum)
            ),
            updated_momentum,
            updated_diagonal_statistics,
        )

        lr = learning_rate
        if callable(learning_rate):
            lr = learning_rate(state.count)

        # optax convention: updates are added to the parameters, hence the -lr scaling.
        new_updates = jax.tree_map(lambda m: -lr * m, updated_momentum)
        return new_updates, SM3State(count=state.count + 1, stats=new_sm3_stats)

    return optax.GradientTransformation(init_fn, update_fn)
